Conversation
Pull request overview
This PR adds restart functionality to the simulation framework, allowing simulations to resume from previously saved state files. The implementation supports loading restart data from both single .part TorchScript files and tar archives containing multiple rank-specific .part files.
Changes:
- Added restart file loading infrastructure with support for tar archives and single .part files
- Modified MeshBlock initialization to accept an optional restart file parameter
- Updated CMake configuration to include libarchive dependency (v3.8.5)
Reviewed changes
Copilot reviewed 14 out of 14 changed files in this pull request and generated 8 comments.
Summary per file:
| File | Description |
|---|---|
| src/mesh/meshblock.hpp | Added restart_file parameter to initialize method; updated _init_from_restart signature to accept filename |
| src/mesh/meshblock.cpp | Refactored restart initialization to use new load_restart function; moved tensor device transfer and timing data cleanup into _init_from_restart |
| src/input/read_restart_file.hpp | Simplified the interface: removed the old helper function and added a clean load_restart API |
| src/input/read_restart_file.cpp | New implementation with tar archive support using libarchive; handles rank-based filtering of restart files |
| src/CMakeLists.txt | Added libarchive include directory and library linkage |
| cmake/libarchive.cmake | New CMake configuration for fetching and building libarchive v3.8.5 |
| pyproject.toml | Updated kintera dependency from >=1.2.6 to >=1.2.9 |
| python/csrc/snapy.cpp | Added Python binding for load_restart function |
| python/csrc/pymesh.cpp | Modified initialize binding to accept optional restart_file parameter |
| examples/straka.cpp | Updated to use new initialize API with restart parameter |
| examples/run_hydro.cpp | Updated to use new initialize API with CLI restart filename |
```cpp
found_part = true;
auto out = parse_part_filename(name);
// find the block rank number after "block"
int rank = std::stoi(out.blockid.substr(5, out.blockid.size() - 5));
```
The code assumes blockid starts with "block" (5 characters) and directly uses substr(5, ...) without validation. If the blockid doesn't follow this expected format (e.g., is shorter than 5 characters), this will cause undefined behavior or throw an out_of_range exception. Consider adding validation to check if blockid starts with "block" before extracting the rank number.
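The check this comment asks for can be sketched as a small validating parser; `parse_block_rank` is a hypothetical helper name, not part of the PR:

```cpp
#include <optional>
#include <stdexcept>
#include <string>

// Hypothetical helper: returns the rank parsed from a blockid such as
// "block3", or std::nullopt when the id does not match the expected
// "block<number>" format.
static std::optional<int> parse_block_rank(const std::string& blockid) {
  constexpr const char* prefix = "block";
  constexpr size_t prefix_len = 5;
  if (blockid.size() <= prefix_len ||
      blockid.compare(0, prefix_len, prefix) != 0) {
    return std::nullopt;
  }
  try {
    size_t pos = 0;
    int rank = std::stoi(blockid.substr(prefix_len), &pos);
    // Reject trailing non-digit characters, e.g. "block3x".
    if (pos != blockid.size() - prefix_len) return std::nullopt;
    return rank;
  } catch (const std::exception&) {  // stoi: invalid_argument / out_of_range
    return std::nullopt;
  }
}
```

Returning `std::optional` lets the caller decide whether a malformed blockid should skip the entry or abort the load.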
```cpp
static std::string dtype_to_string(const at::ScalarType t) {
  // at::toString exists in many builds; this is safe enough.
  return std::string(at::toString(t));
}

static std::string device_to_string(const at::Device& d) {
  std::ostringstream oss;
  oss << d;
  return oss.str();
}

static std::string shape_to_string(const at::Tensor& t) {
  std::ostringstream oss;
  oss << "(";
  for (int64_t i = 0; i < t.dim(); ++i) {
    oss << t.size(i);
    if (i + 1 < t.dim()) oss << ", ";
  }
  oss << ")";
  return oss.str();
}
```
These helper functions (dtype_to_string, device_to_string, shape_to_string) are defined but never used in this file. Consider removing them if they're not needed, or add a comment explaining that they're intended for future use or debugging purposes.
```cpp
  archive_read_free(ar);
} else {
  // Treat as a single .part TorchScript file
  std::cout << "single .part file detected\n";
```
This debug output message should either be removed or changed to use a proper logging mechanism instead of directly writing to stdout. Direct console output in library code can interfere with application-level output formatting and logging.
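One lightweight way to follow this advice, assuming no logging framework is available in the project yet, is a runtime-guarded diagnostic helper writing to `std::clog` rather than `std::cout`; `SNAP_VERBOSE` and `log_debug` are illustrative names only:

```cpp
#include <cstdlib>
#include <iostream>
#include <string>

// Hypothetical verbosity switch; the project may prefer its own logging
// facility, but any runtime guard keeps library diagnostics out of normal
// application output.
static bool snap_verbose() { return std::getenv("SNAP_VERBOSE") != nullptr; }

// Emit a diagnostic message only when verbosity is enabled.
static void log_debug(const std::string& msg) {
  if (snap_verbose()) std::clog << "[snap] " << msg << "\n";
}
```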
```cpp
void load_restart(Variables& vars, std::string const& path) {
  // Dispatch based on whether `path` is a .part file or a tar archive.
  if (is_tar_archive(path)) {
    struct archive* ar = archive_read_new();
    if (!ar) {
      std::cerr << path << ": failed to allocate archive reader\n";
      return;
    }

    archive_read_support_filter_all(ar);
    archive_read_support_format_all(ar);

    int r = archive_read_open_filename(ar, path.c_str(), 10240);
    if (r != ARCHIVE_OK) {
      std::cerr << path
                << ": failed to open archive: " << archive_error_string(ar)
                << "\n";
      archive_read_free(ar);
      return;
    }

    bool found_part = false;

    struct archive_entry* entry = nullptr;
    while ((r = archive_read_next_header(ar, &entry)) == ARCHIVE_OK) {
      const char* name_c = archive_entry_pathname(entry);
      std::string name = name_c ? std::string{name_c} : std::string{};

      if (ends_with(name, ".part")) {
        found_part = true;
        auto out = parse_part_filename(name);
        // find the block rank number after "block"
        int rank = std::stoi(out.blockid.substr(5, out.blockid.size() - 5));
        int my_rank = get_rank();
        if (rank != my_rank) {
          // Not for this rank; skip
          archive_read_data_skip(ar);
        } else {
          load_pt_from_tar(vars, ar, entry);
          return;
        }

        // Note: consume the entry data (via archive_read_data*)
        // or skip it, otherwise the next header read will misbehave.
      } else {
        // Skip non-.part entries quickly
        archive_read_data_skip(ar);
      }
    }

    if (!found_part) {
      std::cerr << path << ": no .part files found in tar archive\n";
    }

    if (r != ARCHIVE_EOF && r != ARCHIVE_OK) {
      std::cerr << path
                << ": error while reading archive: " << archive_error_string(ar)
                << "\n";
    }

    archive_read_close(ar);
    archive_read_free(ar);
  } else {
    // Treat as a single .part TorchScript file
    std::cout << "single .part file detected\n";
    kintera::load_tensors(vars, path);
  }
}
```
The load_restart function silently returns on multiple error conditions (allocation failure, file open failure, no .part files found, read errors), leaving vars potentially empty or partially populated. This makes it difficult for callers to distinguish between success and failure. Consider throwing an exception or returning a status code to indicate errors, so that the caller can handle restart failures appropriately.
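One possible shape for the suggested error reporting, sketched with plain booleans standing in for the real libarchive status checks (`RestartError` and `check_restart_status` are hypothetical names, not part of the PR):

```cpp
#include <stdexcept>
#include <string>

// Hypothetical error type; load_restart (or a checked variant) would throw
// it instead of printing to stderr and returning silently.
struct RestartError : std::runtime_error {
  explicit RestartError(const std::string& msg) : std::runtime_error(msg) {}
};

// Sketch of the error paths only; archive handling elided. The booleans
// stand in for the real libarchive status results.
static void check_restart_status(const std::string& path, bool archive_ok,
                                 bool found_part) {
  if (!archive_ok) throw RestartError(path + ": failed to open archive");
  if (!found_part)
    throw RestartError(path + ": no .part files found in tar archive");
}
```

Throwing also guarantees that a caller cannot accidentally continue with a partially populated `vars` map.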
```cpp
load_restart(vars, fname);

cycle = timing_vars.at("last_cycle").item<int64_t>();
cycle = vars.at("last_cycle").item<int64_t>() - 1;
```
The cycle is set to last_cycle - 1, which seems counterintuitive. If the restart file contains the last completed cycle, the next cycle should typically be last_cycle + 1, not last_cycle - 1. This decrement could cause the simulation to repeat a cycle or have an incorrect cycle count. Please verify this is the intended behavior and add a comment explaining the reasoning if this is correct.
Suggested change:

```diff
- cycle = vars.at("last_cycle").item<int64_t>() - 1;
+ // last_cycle in the restart file stores the index of the last completed cycle.
+ // Set cycle to the next cycle to run so that the simulation continues forward.
+ cycle = vars.at("last_cycle").item<int64_t>() + 1;
```
```cpp
// C/C++
#include <filesystem>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <optional>
#include <random>
#include <sstream>
#include <stdexcept>
#include <string>
#include <tuple>
#include <vector>

// archive
#include <archive.h>
#include <archive_entry.h>

// kintera
#include <kintera/utils/serialize.hpp>

// snap
#include <snap/layout/layout.hpp>

namespace fs = std::filesystem;

namespace snap {

// -------------------------
// Small helpers
// -------------------------

struct RestartFields {
  std::string basename;
  std::string blockid;
  std::string filenumber;
};

static RestartFields parse_part_filename(const std::string& name) {
  constexpr std::string_view suffix = ".part";

  if (name.size() <= suffix.size() ||
      name.compare(name.size() - suffix.size(), suffix.size(), suffix) != 0) {
    throw std::invalid_argument("filename does not end with .part");
  }

  // Strip ".part"
  const std::string_view core(name.data(), name.size() - suffix.size());

  // Find last two dots
  const size_t dot2 = core.rfind('.');
  if (dot2 == std::string::npos) {
    throw std::invalid_argument("filename missing filenumber field");
  }

  const size_t dot1 = core.rfind('.', dot2 - 1);
  if (dot1 == std::string::npos) {
    throw std::invalid_argument("filename missing block_id field");
  }

  RestartFields out;
  out.basename = std::string(core.substr(0, dot1));
  out.blockid = std::string(core.substr(dot1 + 1, dot2 - dot1 - 1));
  out.filenumber = std::string(core.substr(dot2 + 1));

  if (out.basename.empty() || out.blockid.empty() || out.filenumber.empty()) {
    throw std::invalid_argument("one or more filename fields are empty");
  }

  return out;
}

static std::string dtype_to_string(const at::ScalarType t) {
  // at::toString exists in many builds; this is safe enough.
  return std::string(at::toString(t));
}

static std::string device_to_string(const at::Device& d) {
  std::ostringstream oss;
  oss << d;
  return oss.str();
}

static std::string shape_to_string(const at::Tensor& t) {
  std::ostringstream oss;
  oss << "(";
  for (int64_t i = 0; i < t.dim(); ++i) {
    oss << t.size(i);
    if (i + 1 < t.dim()) oss << ", ";
  }
  oss << ")";
  return oss.str();
}

static bool ends_with(std::string const& s, std::string const& suffix) {
  return s.size() >= suffix.size() &&
         s.compare(s.size() - suffix.size(), suffix.size(), suffix) == 0;
}

static bool is_tar_archive(std::string const& path) {
  if (!fs::is_regular_file(path)) return false;

  struct archive* ar = archive_read_new();
  if (!ar) return false;

  archive_read_support_filter_all(ar);
  archive_read_support_format_all(ar);

  // Try opening as an archive; if it succeeds, treat as tar-like.
  int r = archive_read_open_filename(ar, path.c_str(), 10240);
  if (r != ARCHIVE_OK) {
    archive_read_free(ar);
    return false;
  }

  // Some files might be recognized as other archive formats too; in practice
  // this matches Python's "is_tarfile" intent well.
  archive_read_close(ar);
  archive_read_free(ar);
  return true;
}

// Create a unique temp file path; not bulletproof, but good enough
static fs::path make_temp_path(std::string_view suffix) {
  fs::path dir = fs::temp_directory_path();

  std::random_device rd;
  std::mt19937_64 gen(rd());
  std::uniform_int_distribution<uint64_t> dis;

  for (int tries = 0; tries < 20; ++tries) {
    uint64_t r = dis(gen);
    std::ostringstream name;
    name << "tmp_" << std::hex << r << suffix;
    fs::path p = dir / name.str();
    if (!fs::exists(p)) return p;
  }

  // Fallback (very unlikely to collide)
  return dir / ("tmp_fallback" + std::string(suffix));
}

static void load_pt_from_tar(Variables& vars, struct archive* ar,
                             struct archive_entry* entry) {
  const char* name_c = archive_entry_pathname(entry);
  std::string member_name =
      name_c ? std::string{name_c} : std::string{"<unknown>"};

  // Extract this entry into a temporary file (TorchScript loader prefers
  // real/seekable file)
  fs::path tmp_path = make_temp_path(".part");
  std::ofstream out(tmp_path, std::ios::binary);
  if (!out) {
    std::cerr << "\n=== " << member_name << " ===\n";
    std::cerr << " ERROR: could not create temp file: " << tmp_path.string()
              << "\n";
    // Must still consume/skip entry data:
    archive_read_data_skip(ar);
    return;
  }

  std::vector<char> buf(1 << 20);
  while (true) {
    la_ssize_t n = archive_read_data(ar, buf.data(), buf.size());
    if (n == 0) break;  // end of this entry
    if (n < 0) {
      std::cerr << "\n=== " << member_name << " ===\n";
      std::cerr << " ERROR: could not extract file from tar: "
                << archive_error_string(ar) << "\n";
      out.close();
      std::error_code ec;
      fs::remove(tmp_path, ec);
      return;
    }
    out.write(buf.data(), static_cast<std::streamsize>(n));
    if (!out) {
      std::cerr << "\n=== " << member_name << " ===\n";
      std::cerr << " ERROR: failed writing temp file\n";
      out.close();
      std::error_code ec;
      fs::remove(tmp_path, ec);
      return;
    }
  }

  out.flush();
  out.close();

  // load the extracted .part
  kintera::load_tensors(vars, tmp_path.string());

  // remove empty tensors (if any)
  for (auto it = vars.begin(); it != vars.end();) {
    if (!it->second.defined() || it->second.numel() == 0) {
      it = vars.erase(it);
    } else {
      ++it;
    }
  }

  // Cleanup
  std::error_code ec;
  fs::remove(tmp_path, ec);
}

void load_restart(Variables& vars, std::string const& path) {
  // Dispatch based on whether `path` is a .part file or a tar archive.
  if (is_tar_archive(path)) {
    struct archive* ar = archive_read_new();
    if (!ar) {
      std::cerr << path << ": failed to allocate archive reader\n";
      return;
    }

    archive_read_support_filter_all(ar);
    archive_read_support_format_all(ar);

    int r = archive_read_open_filename(ar, path.c_str(), 10240);
    if (r != ARCHIVE_OK) {
      std::cerr << path
                << ": failed to open archive: " << archive_error_string(ar)
                << "\n";
      archive_read_free(ar);
      return;
    }

    bool found_part = false;

    struct archive_entry* entry = nullptr;
    while ((r = archive_read_next_header(ar, &entry)) == ARCHIVE_OK) {
      const char* name_c = archive_entry_pathname(entry);
      std::string name = name_c ? std::string{name_c} : std::string{};

      if (ends_with(name, ".part")) {
        found_part = true;
        auto out = parse_part_filename(name);
        // find the block rank number after "block"
        int rank = std::stoi(out.blockid.substr(5, out.blockid.size() - 5));
        int my_rank = get_rank();
        if (rank != my_rank) {
          // Not for this rank; skip
          archive_read_data_skip(ar);
        } else {
          load_pt_from_tar(vars, ar, entry);
          return;
        }

        // Note: consume the entry data (via archive_read_data*)
        // or skip it, otherwise the next header read will misbehave.
      } else {
        // Skip non-.part entries quickly
        archive_read_data_skip(ar);
      }
    }

    if (!found_part) {
      std::cerr << path << ": no .part files found in tar archive\n";
    }

    if (r != ARCHIVE_EOF && r != ARCHIVE_OK) {
      std::cerr << path
                << ": error while reading archive: " << archive_error_string(ar)
                << "\n";
    }

    archive_read_close(ar);
    archive_read_free(ar);
  } else {
    // Treat as a single .part TorchScript file
    std::cout << "single .part file detected\n";
    kintera::load_tensors(vars, path);
  }
}

}  // namespace snap
```
The new restart functionality lacks test coverage. Consider adding tests that verify: (1) loading restart files from tar archives, (2) loading single .part files, (3) correct extraction of timing data (last_time, last_cycle, file_number, next_time), (4) proper handling of missing files or corrupted archives, and (5) correct rank-based filtering of .part files in multi-rank scenarios.
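As one concrete starting point, the rank-based filtering (point 5) can be unit-tested without torch or libarchive by exercising the filename logic in isolation; the helpers below are simplified stand-ins for the real implementations in src/input/read_restart_file.cpp, and `entry_is_for_rank` is an illustrative name, not part of the PR:

```cpp
#include <string>

// Stand-in for the ends_with helper in read_restart_file.cpp.
static bool ends_with(const std::string& s, const std::string& suffix) {
  return s.size() >= suffix.size() &&
         s.compare(s.size() - suffix.size(), suffix.size(), suffix) == 0;
}

// Rank filter used when scanning tar entries: only ".part" entries whose
// blockid matches this rank should be loaded.
static bool entry_is_for_rank(const std::string& name, int my_rank) {
  if (!ends_with(name, ".part")) return false;
  // blockid is the field between the last two dots, e.g. "block2" in
  // "run.block2.00010.part".
  std::string core = name.substr(0, name.size() - 5);
  size_t dot2 = core.rfind('.');
  if (dot2 == std::string::npos) return false;
  size_t dot1 = core.rfind('.', dot2 - 1);
  if (dot1 == std::string::npos) return false;
  std::string blockid = core.substr(dot1 + 1, dot2 - dot1 - 1);
  return blockid == "block" + std::to_string(my_rank);
}
```

Tests for the archive paths (points 1 and 4) would additionally need a fixture tar built at test setup time.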
```cpp
// Create a unique temp file path; not bulletproof, but good enough
static fs::path make_temp_path(std::string_view suffix) {
  fs::path dir = fs::temp_directory_path();

  std::random_device rd;
  std::mt19937_64 gen(rd());
  std::uniform_int_distribution<uint64_t> dis;

  for (int tries = 0; tries < 20; ++tries) {
    uint64_t r = dis(gen);
    std::ostringstream name;
    name << "tmp_" << std::hex << r << suffix;
    fs::path p = dir / name.str();
    if (!fs::exists(p)) return p;
  }

  // Fallback (very unlikely to collide)
  return dir / ("tmp_fallback" + std::string(suffix));
}
```
The temporary file creation uses a random number generator but doesn't set permissions on the created file. Consider using std::filesystem operations with explicit permissions (e.g., fs::perms::owner_read | fs::perms::owner_write) to prevent other users from accessing potentially sensitive restart data in the temporary directory.
```cpp
  hydro_w.index(interior)[IPR] = in_vars["press"];
}
using Variables = std::map<std::string, torch::Tensor>;
```
The load_restart function lacks documentation. Consider adding a comment block that describes: (1) the purpose of the function, (2) the expected format of the restart file (tar archive or single .part file), (3) what variables are loaded into the vars map, (4) what happens on error (currently silent failure), and (5) how multi-rank scenarios are handled.
Suggested change:

```cpp
/**
 * Load model/state variables from a restart file into the given map.
 *
 * The restart file at \p path is expected to be either:
 *   - a tar archive containing one or more rank-specific `.part` files, or
 *   - a single `.part` file corresponding to the current rank.
 *
 * On success, \p vars is populated with named tensors deserialized from
 * the restart file. Keys in the map correspond to variable names stored
 * in the restart, and values are the associated torch::Tensor objects.
 *
 * Error handling:
 *   - The function does not throw or return an explicit error code.
 *   - I/O or parsing failures typically result in an empty or partially
 *     populated \p vars map, without additional notification.
 *     Callers are responsible for validating that the expected variables
 *     have been loaded (e.g., by checking for required keys).
 *
 * Multi-rank usage:
 *   - In multi-rank runs, this function is intended to be called
 *     independently by each rank.
 *   - The caller must provide a \p path that resolves to the appropriate
 *     restart data for that rank (e.g., a rank-specific `.part` file
 *     within a tar archive or as a standalone file).
 */
```
|